{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Fine-tuning TapasForQuestionAnswering on SQA.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true,
      "mount_file_id": "1zMW-D2kYrpDA-cvpNJ-ctGD-tDXWebZa",
      "authorship_tag": "ABX9TyOiiRY4UpoKouOKLEeuEpVF",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "768723f1af4c4bf497064a9796567382": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_82bfe5d88c5546358de33952e42221a3",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_c6aadbda246d47fcad9660c2f57926a0",
              "IPY_MODEL_cd910003940b498e94db9746091e4c3c"
            ]
          }
        },
        "82bfe5d88c5546358de33952e42221a3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "c6aadbda246d47fcad9660c2f57926a0": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_a11bc2e0d34543d1bb0ab708e0cc1a67",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1432,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1432,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_d59737fa0a424a90ba02110e4cf1b639"
          }
        },
        "cd910003940b498e94db9746091e4c3c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_d534c2ebdbb144efa59500fcadf6e118",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1.43k/1.43k [00:00&lt;00:00, 4.46kB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_f6fef10b29f74c458d05c2cbda46bf4a"
          }
        },
        "a11bc2e0d34543d1bb0ab708e0cc1a67": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "d59737fa0a424a90ba02110e4cf1b639": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "d534c2ebdbb144efa59500fcadf6e118": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "f6fef10b29f74c458d05c2cbda46bf4a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "695d0c44ebbe4a55a17c735645c26c82": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_7e3cd6c49143436bad3d84be7ca2cf79",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_be6159a5be0945629a517297df1170b8",
              "IPY_MODEL_b889ffa28a1949e9ac46e205ee582688"
            ]
          }
        },
        "7e3cd6c49143436bad3d84be7ca2cf79": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "be6159a5be0945629a517297df1170b8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_3c95dd6249784fe6a2a466737e5f5866",
            "_dom_classes": [],
            "description": "Downloading: 100%",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 442768791,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 442768791,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_cf9382ee988f4787a88340d3b2af95f3"
          }
        },
        "b889ffa28a1949e9ac46e205ee582688": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_45fdca8a04f94b43b021c590181e8a48",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 443M/443M [00:06&lt;00:00, 66.0MB/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3bad5ab136084c2e92b55153828a403f"
          }
        },
        "3c95dd6249784fe6a2a466737e5f5866": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "cf9382ee988f4787a88340d3b2af95f3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "45fdca8a04f94b43b021c590181e8a48": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3bad5ab136084c2e92b55153828a403f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l5Ds1ZM41KC9"
      },
      "source": [
        "## Introduction: TAPAS\r\n",
        "\r\n",
        "* Original TAPAS paper (ACL 2020): https://www.aclweb.org/anthology/2020.acl-main.398/\r\n",
        "* Follow-up paper on intermediate pre-training (EMMNLP Findings 2020): https://www.aclweb.org/anthology/2020.findings-emnlp.27/\r\n",
        "* Original Github repository: https://github.com/google-research/tapas\r\n",
        "* Blog post: https://ai.googleblog.com/2020/04/using-neural-networks-to-find-answers.html\r\n",
        "\r\n",
        "TAPAS is an algorithm that (among other tasks) can answer questions about tabular data. It is essentially a BERT model with relative position embeddings and additional token type ids that encode tabular structure, and 2 classification heads on top: one for **cell selection** and one for (optionally) performing an **aggregation** among selected cells (such as summing or counting).\r\n",
        "\r\n",
        "Similar to BERT, the base `TapasModel` is pre-trained using the masked language modeling (MLM) objective on a large collection of tables from Wikipedia and associated texts. In addition, the authors further pre-trained the model on an second task (table entailment) to increase the numerical reasoning capabilities of TAPAS (as explained in the follow-up paper), which further improves performance on downstream tasks. \r\n",
        "\r\n",
        "In this notebook, we are going to fine-tune `TapasForQuestionAnswering` on [Sequential Question Answering (SQA)](https://www.microsoft.com/en-us/research/publication/search-based-neural-structured-learning-sequential-question-answering/), a dataset built by Microsoft Research which deals with asking questions related to a table in a **conversational set-up**. We are going to do so as in the original paper, by adding a randomly initialized cell selection head on top of the pre-trained base model (note that SQA does not have questions that involve aggregation and hence no aggregation head), and then fine-tuning them altogether.\r\n",
        "\r\n",
        "First, we install both the Transformers library as well as the dependency on [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), which the model requires."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MUMrt5Ow_PEA",
        "outputId": "eda7d53e-9846-4941-ed72-ce84f495469f"
      },
      "source": [
        "! rm -r transformers\r\n",
        "! git clone https://github.com/huggingface/transformers.git\r\n",
        "! cd transformers\r\n",
        "! pip install ./transformers"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "rm: cannot remove 'transformers': No such file or directory\n",
            "Cloning into 'transformers'...\n",
            "remote: Enumerating objects: 32, done.\u001b[K\n",
            "remote: Counting objects: 100% (32/32), done.\u001b[K\n",
            "remote: Compressing objects: 100% (25/25), done.\u001b[K\n",
            "remote: Total 56845 (delta 9), reused 4 (delta 2), pack-reused 56813\u001b[K\n",
            "Receiving objects: 100% (56845/56845), 42.37 MiB | 30.23 MiB/s, done.\n",
            "Resolving deltas: 100% (39845/39845), done.\n",
            "Processing ./transformers\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (3.0.12)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (4.41.1)\n",
            "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (0.8)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (1.18.5)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (2.23.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (2019.12.20)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers==4.1.0.dev0) (20.7)\n",
            "Collecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
            "\u001b[K     |████████████████████████████████| 890kB 14.0MB/s \n",
            "\u001b[?25hCollecting tokenizers==0.9.4\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/0f/1c/e789a8b12e28be5bc1ce2156cf87cb522b379be9cadc7ad8091a4cc107c4/tokenizers-0.9.4-cp36-cp36m-manylinux2010_x86_64.whl (2.9MB)\n",
            "\u001b[K     |████████████████████████████████| 2.9MB 57.7MB/s \n",
            "\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==4.1.0.dev0) (3.0.4)\n",
            "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==4.1.0.dev0) (2.10)\n",
            "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==4.1.0.dev0) (1.24.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers==4.1.0.dev0) (2020.12.5)\n",
            "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers==4.1.0.dev0) (2.4.7)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==4.1.0.dev0) (1.15.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==4.1.0.dev0) (7.1.2)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==4.1.0.dev0) (0.17.0)\n",
            "Building wheels for collected packages: transformers\n",
            "  Building wheel for transformers (PEP 517) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for transformers: filename=transformers-4.1.0.dev0-cp36-none-any.whl size=1507800 sha256=559a0eba87c03766368f6b4d9d2509544a972e6f70c035ab19ebd841b4fa46fd\n",
            "  Stored in directory: /tmp/pip-ephem-wheel-cache-3k1te24y/wheels/23/19/dd/2561a4e47240cf6b307729d58e56f8077dd0c698f5992216cf\n",
            "Successfully built transformers\n",
            "Building wheels for collected packages: sacremoses\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=07011c2c7aebdb1168e87ea5c3569b5600083b5291f6f38c90adeafb3851d2e2\n",
            "  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
            "Successfully built sacremoses\n",
            "Installing collected packages: sacremoses, tokenizers, transformers\n",
            "Successfully installed sacremoses-0.0.43 tokenizers-0.9.4 transformers-4.1.0.dev0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gx4u09iTyRjY",
        "outputId": "e4cd9f4b-7d8d-4b47-e8b2-304b921dba98"
      },
      "source": [
        "! pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Looking in links: https://pytorch-geometric.com/whl/torch-1.7.0.html\n",
            "Collecting torch-scatter==latest+cu101\n",
            "\u001b[?25l  Downloading https://pytorch-geometric.com/whl/torch-1.7.0/torch_scatter-latest%2Bcu101-cp36-cp36m-linux_x86_64.whl (11.9MB)\n",
            "\u001b[K     |████████████████████████████████| 11.9MB 41.9MB/s \n",
            "\u001b[?25hInstalling collected packages: torch-scatter\n",
            "Successfully installed torch-scatter-2.0.5\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BSZfmBt0meYm"
      },
      "source": [
        "We also install a small portion from the SQA training dataset, for demonstration purposes. This is a TSV file containing table-question pairs. Besides this, we also download the `table_csv` directory, which contains the actual tabular data.\r\n",
        "\r\n",
        "Note that you can download the entire SQA dataset on the [official website](https://www.microsoft.com/en-us/download/details.aspx?id=54253)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wsuwgDEU4J_f"
      },
      "source": [
        "import requests, zipfile, io\r\n",
        "import os\r\n",
        "\r\n",
        "def download_files(dir_name):\r\n",
        "  if not os.path.exists(dir_name): \r\n",
        "    # 28 training examples from the SQA training set + table csv data\r\n",
        "    urls = [\"https://www.dropbox.com/s/2p6ez9xro357i63/sqa_train_set_28_examples.zip?dl=1\",\r\n",
        "            \"https://www.dropbox.com/s/abhum8ssuow87h6/table_csv.zip?dl=1\"\r\n",
        "    ]\r\n",
        "    for url in urls:\r\n",
        "      r = requests.get(url)\r\n",
        "      z = zipfile.ZipFile(io.BytesIO(r.content))\r\n",
        "      z.extractall()\r\n",
        "\r\n",
        "dir_name = \"sqa_data\"\r\n",
        "download_files(dir_name)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EPrYJOn81f0D"
      },
      "source": [
        "## Prepare the data \r\n",
        "\r\n",
        "Let's look at the first few rows of the dataset:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 279
        },
        "id": "2X27wyd805D8",
        "outputId": "7ccfd32c-e8dd-4fec-c044-d7d8de8dd578"
      },
      "source": [
        "import pandas as pd\r\n",
        "\r\n",
        "data = pd.read_excel(\"sqa_train_set_28_examples.xlsx\")\r\n",
        "data.head()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>id</th>\n",
              "      <th>annotator</th>\n",
              "      <th>position</th>\n",
              "      <th>question</th>\n",
              "      <th>table_file</th>\n",
              "      <th>answer_coordinates</th>\n",
              "      <th>answer_text</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>where are the players from?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>['(0, 4)', '(1, 4)', '(2, 4)', '(3, 4)', '(4, ...</td>\n",
              "      <td>['Louisiana State University', 'Valley HS (Las...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>which player went to louisiana state university?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>['(0, 1)']</td>\n",
              "      <td>['Ben McDonald']</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>who are the players?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...</td>\n",
              "      <td>['Ben McDonald', 'Tyler Houston', 'Roger Salke...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>which ones are in the top 26 picks?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>['(0, 1)', '(1, 1)', '(2, 1)', '(3, 1)', '(4, ...</td>\n",
              "      <td>['Ben McDonald', 'Tyler Houston', 'Roger Salke...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>2</td>\n",
              "      <td>and of those, who is from louisiana state univ...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>['(0, 1)']</td>\n",
              "      <td>['Ben McDonald']</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "       id  ...                                        answer_text\n",
              "0  nt-639  ...  ['Louisiana State University', 'Valley HS (Las...\n",
              "1  nt-639  ...                                   ['Ben McDonald']\n",
              "2  nt-639  ...  ['Ben McDonald', 'Tyler Houston', 'Roger Salke...\n",
              "3  nt-639  ...  ['Ben McDonald', 'Tyler Houston', 'Roger Salke...\n",
              "4  nt-639  ...                                   ['Ben McDonald']\n",
              "\n",
              "[5 rows x 7 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OMJ4dNBV1oj6"
      },
      "source": [
        "As you can see, each row corresponds to a question related to a table. \r\n",
        "* The `position` column identifies whether the question is the first, second, ... in a sequence of questions related to a table. \r\n",
        "* The `table_file` column identifies the name of the table file, which refers to a CSV file in the `table_csv` directory.\r\n",
        "* The `answer_coordinates` and `answer_text` columns indicate the answer to the question. The `answer_coordinates` is a list of tuples, each tuple being a (row_index, column_index) pair. The `answer_text` column is a list of strings, indicating the cell values.\r\n",
        "\r\n",
        "However, the `answer_coordinates` and `answer_text` columns are currently not recognized as real Python lists of Python tuples and strings respectively. Let's do that first using the `.literal_eval()`function of the `ast` module:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 511
        },
        "id": "BAovAs5s1k10",
        "outputId": "0849a829-4ae8-43e9-e138-177fa14e3e36"
      },
      "source": [
        "import ast\r\n",
        "\r\n",
        "def _parse_answer_coordinates(answer_coordinate_str):\r\n",
        "  \"\"\"Parses the answer_coordinates of a question.\r\n",
        "  Args:\r\n",
        "    answer_coordinate_str: A string representation of a Python list of tuple\r\n",
        "      strings.\r\n",
        "      For example: \"['(1, 4)','(1, 3)', ...]\"\r\n",
        "  \"\"\"\r\n",
        "\r\n",
        "  try:\r\n",
        "    answer_coordinates = []\r\n",
        "    # make a list of strings\r\n",
        "    coords = ast.literal_eval(answer_coordinate_str)\r\n",
        "    # parse each string as a tuple\r\n",
        "    for row_index, column_index in sorted(\r\n",
        "        ast.literal_eval(coord) for coord in coords):\r\n",
        "      answer_coordinates.append((row_index, column_index))\r\n",
        "  except SyntaxError:\r\n",
        "    raise ValueError('Unable to evaluate %s' % answer_coordinate_str)\r\n",
        "  \r\n",
        "  return answer_coordinates\r\n",
        "\r\n",
        "\r\n",
        "def _parse_answer_text(answer_text):\r\n",
        "  \"\"\"Populates the answer_texts field of `answer` by parsing `answer_text`.\r\n",
        "  Args:\r\n",
        "    answer_text: A string representation of a Python list of strings.\r\n",
        "      For example: \"[u'test', u'hello', ...]\"\r\n",
        "    answer: an Answer object.\r\n",
        "  \"\"\"\r\n",
        "  try:\r\n",
        "    answer = []\r\n",
        "    for value in ast.literal_eval(answer_text):\r\n",
        "      answer.append(value)\r\n",
        "  except SyntaxError:\r\n",
        "    raise ValueError('Unable to evaluate %s' % answer_text)\r\n",
        "\r\n",
        "  return answer\r\n",
        "\r\n",
        "data['answer_coordinates'] = data['answer_coordinates'].apply(lambda coords_str: _parse_answer_coordinates(coords_str))\r\n",
        "data['answer_text'] = data['answer_text'].apply(lambda txt: _parse_answer_text(txt))\r\n",
        "\r\n",
        "data.head(10)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>id</th>\n",
              "      <th>annotator</th>\n",
              "      <th>position</th>\n",
              "      <th>question</th>\n",
              "      <th>table_file</th>\n",
              "      <th>answer_coordinates</th>\n",
              "      <th>answer_text</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>where are the players from?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, 4...</td>\n",
              "      <td>[Louisiana State University, Valley HS (Las Ve...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>which player went to louisiana state university?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1)]</td>\n",
              "      <td>[Ben McDonald]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>who are the players?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Ben McDonald, Tyler Houston, Roger Salkeld, J...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>which ones are in the top 26 picks?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Ben McDonald, Tyler Houston, Roger Salkeld, J...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>2</td>\n",
              "      <td>and of those, who is from louisiana state univ...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1)]</td>\n",
              "      <td>[Ben McDonald]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>2</td>\n",
              "      <td>0</td>\n",
              "      <td>who are the players in the top 26?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Ben McDonald, Tyler Houston, Roger Salkeld, J...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>2</td>\n",
              "      <td>1</td>\n",
              "      <td>of those, which one was from louisiana state u...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1)]</td>\n",
              "      <td>[Ben McDonald]</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>nt-11649</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>what are all the names of the teams?</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Cordoba CF, CD Malaga, Granada CF, UD Las Pal...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>nt-11649</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>of these, which teams had any losses?</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Cordoba CF, CD Malaga, Granada CF, UD Las Pal...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>nt-11649</td>\n",
              "      <td>0</td>\n",
              "      <td>2</td>\n",
              "      <td>of these teams, which had more than 21 losses?</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[(15, 1)]</td>\n",
              "      <td>[CD Villarrobledo]</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "         id  ...                                        answer_text\n",
              "0    nt-639  ...  [Louisiana State University, Valley HS (Las Ve...\n",
              "1    nt-639  ...                                     [Ben McDonald]\n",
              "2    nt-639  ...  [Ben McDonald, Tyler Houston, Roger Salkeld, J...\n",
              "3    nt-639  ...  [Ben McDonald, Tyler Houston, Roger Salkeld, J...\n",
              "4    nt-639  ...                                     [Ben McDonald]\n",
              "5    nt-639  ...  [Ben McDonald, Tyler Houston, Roger Salkeld, J...\n",
              "6    nt-639  ...                                     [Ben McDonald]\n",
              "7  nt-11649  ...  [Cordoba CF, CD Malaga, Granada CF, UD Las Pal...\n",
              "8  nt-11649  ...  [Cordoba CF, CD Malaga, Granada CF, UD Las Pal...\n",
              "9  nt-11649  ...                                 [CD Villarrobledo]\n",
              "\n",
              "[10 rows x 7 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "X7FYPpdW5dY4"
      },
      "source": [
        "Let's create a new dataframe that groups questions which are asked in a sequence related to the table. We can do this by adding a `sequence_id` column, which is a combination of the `id` and `annotator` columns:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 279
        },
        "id": "O1Quo0FL7h9-",
        "outputId": "5223d575-b86d-41e6-b23a-6071b3048211"
      },
      "source": [
        "def get_sequence_id(example_id, annotator):\r\n",
        "  if \"-\" in str(annotator):\r\n",
        "    raise ValueError('\"-\" not allowed in annotator.')\r\n",
        "  return f\"{example_id}-{annotator}\"\r\n",
        "\r\n",
        "data['sequence_id'] = data.apply(lambda x: get_sequence_id(x.id, x.annotator), axis=1)\r\n",
        "data.head()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>id</th>\n",
              "      <th>annotator</th>\n",
              "      <th>position</th>\n",
              "      <th>question</th>\n",
              "      <th>table_file</th>\n",
              "      <th>answer_coordinates</th>\n",
              "      <th>answer_text</th>\n",
              "      <th>sequence_id</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>0</td>\n",
              "      <td>where are the players from?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, 4...</td>\n",
              "      <td>[Louisiana State University, Valley HS (Las Ve...</td>\n",
              "      <td>nt-639-0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>0</td>\n",
              "      <td>1</td>\n",
              "      <td>which player went to louisiana state university?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1)]</td>\n",
              "      <td>[Ben McDonald]</td>\n",
              "      <td>nt-639-0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>0</td>\n",
              "      <td>who are the players?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Ben McDonald, Tyler Houston, Roger Salkeld, J...</td>\n",
              "      <td>nt-639-1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>1</td>\n",
              "      <td>which ones are in the top 26 picks?</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1...</td>\n",
              "      <td>[Ben McDonald, Tyler Houston, Roger Salkeld, J...</td>\n",
              "      <td>nt-639-1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>nt-639</td>\n",
              "      <td>1</td>\n",
              "      <td>2</td>\n",
              "      <td>and of those, who is from louisiana state univ...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[(0, 1)]</td>\n",
              "      <td>[Ben McDonald]</td>\n",
              "      <td>nt-639-1</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "       id  ...  sequence_id\n",
              "0  nt-639  ...     nt-639-0\n",
              "1  nt-639  ...     nt-639-0\n",
              "2  nt-639  ...     nt-639-1\n",
              "3  nt-639  ...     nt-639-1\n",
              "4  nt-639  ...     nt-639-1\n",
              "\n",
              "[5 rows x 8 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 541
        },
        "id": "-uPpds5D762B",
        "outputId": "38aa6f13-2cc7-4d96-b8b3-a510288bfca2"
      },
      "source": [
        "# let's group table-question pairs by sequence id, and remove some columns we don't need \r\n",
        "grouped = data.groupby(by='sequence_id').agg(lambda x: x.tolist())\r\n",
        "grouped = grouped.drop(columns=['id', 'annotator', 'position'])\r\n",
        "grouped['table_file'] = grouped['table_file'].apply(lambda x: x[0])\r\n",
        "grouped.head(10)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>question</th>\n",
              "      <th>table_file</th>\n",
              "      <th>answer_coordinates</th>\n",
              "      <th>answer_text</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>sequence_id</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>ns-1292-0</th>\n",
              "      <td>[who are all the athletes?, where are they fro...</td>\n",
              "      <td>table_csv/204_521.csv</td>\n",
              "      <td>[[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, ...</td>\n",
              "      <td>[[Tommy Green, Janis Dalins, Ugo Frigerio, Kar...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-10730-0</th>\n",
              "      <td>[what was the production numbers of each revol...</td>\n",
              "      <td>table_csv/203_253.csv</td>\n",
              "      <td>[[(0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, ...</td>\n",
              "      <td>[[1,900 (estimated), 14,500 (estimated), 6,000...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-10730-1</th>\n",
              "      <td>[what three revolver models had the least amou...</td>\n",
              "      <td>table_csv/203_253.csv</td>\n",
              "      <td>[[(0, 0), (6, 0), (7, 0)], [(0, 0)]]</td>\n",
              "      <td>[[Remington-Beals Army Model Revolver, New Mod...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-10730-2</th>\n",
              "      <td>[what are all of the remington models?, how ma...</td>\n",
              "      <td>table_csv/203_253.csv</td>\n",
              "      <td>[[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, ...</td>\n",
              "      <td>[[Remington-Beals Army Model Revolver, Remingt...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-11649-0</th>\n",
              "      <td>[what are all the names of the teams?, of thes...</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, ...</td>\n",
              "      <td>[[Cordoba CF, CD Malaga, Granada CF, UD Las Pa...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-11649-1</th>\n",
              "      <td>[what are the losses?, what team had more than...</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[[(0, 6), (1, 6), (2, 6), (3, 6), (4, 6), (5, ...</td>\n",
              "      <td>[[6, 6, 9, 10, 10, 12, 12, 11, 13, 14, 15, 14,...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-11649-2</th>\n",
              "      <td>[what were all the teams?, what were the loss ...</td>\n",
              "      <td>table_csv/204_135.csv</td>\n",
              "      <td>[[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, ...</td>\n",
              "      <td>[[Cordoba CF, CD Malaga, Granada CF, UD Las Pa...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-639-0</th>\n",
              "      <td>[where are the players from?, which player wen...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[[(0, 4), (1, 4), (2, 4), (3, 4), (4, 4), (5, ...</td>\n",
              "      <td>[[Louisiana State University, Valley HS (Las V...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-639-1</th>\n",
              "      <td>[who are the players?, which ones are in the t...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, ...</td>\n",
              "      <td>[[Ben McDonald, Tyler Houston, Roger Salkeld, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>nt-639-2</th>\n",
              "      <td>[who are the players in the top 26?, of those,...</td>\n",
              "      <td>table_csv/203_149.csv</td>\n",
              "      <td>[[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, ...</td>\n",
              "      <td>[[Ben McDonald, Tyler Houston, Roger Salkeld, ...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                                                      question  ...                                        answer_text\n",
              "sequence_id                                                     ...                                                   \n",
              "ns-1292-0    [who are all the athletes?, where are they fro...  ...  [[Tommy Green, Janis Dalins, Ugo Frigerio, Kar...\n",
              "nt-10730-0   [what was the production numbers of each revol...  ...  [[1,900 (estimated), 14,500 (estimated), 6,000...\n",
              "nt-10730-1   [what three revolver models had the least amou...  ...  [[Remington-Beals Army Model Revolver, New Mod...\n",
              "nt-10730-2   [what are all of the remington models?, how ma...  ...  [[Remington-Beals Army Model Revolver, Remingt...\n",
              "nt-11649-0   [what are all the names of the teams?, of thes...  ...  [[Cordoba CF, CD Malaga, Granada CF, UD Las Pa...\n",
              "nt-11649-1   [what are the losses?, what team had more than...  ...  [[6, 6, 9, 10, 10, 12, 12, 11, 13, 14, 15, 14,...\n",
              "nt-11649-2   [what were all the teams?, what were the loss ...  ...  [[Cordoba CF, CD Malaga, Granada CF, UD Las Pa...\n",
              "nt-639-0     [where are the players from?, which player wen...  ...  [[Louisiana State University, Valley HS (Las V...\n",
              "nt-639-1     [who are the players?, which ones are in the t...  ...  [[Ben McDonald, Tyler Houston, Roger Salkeld, ...\n",
              "nt-639-2     [who are the players in the top 26?, of those,...  ...  [[Ben McDonald, Tyler Houston, Roger Salkeld, ...\n",
              "\n",
              "[10 rows x 4 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r6RKTkSeLLyJ"
      },
      "source": [
        "Each row in the dataframe above now consists of a **table and one or more questions** which are asked in a **sequence**. Let's visualize the first row, i.e. a table, together with its queries:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 525
        },
        "id": "J-dTi5omLdN_",
        "outputId": "b8e1d893-8d8b-4540-dc35-57586312c992"
      },
      "source": [
        "# path to the directory containing all csv files\r\n",
        "table_csv_path = \"table_csv\"\r\n",
        "\r\n",
        "item = grouped.iloc[0]\r\n",
        "table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) \r\n",
        "\r\n",
        "display(table)\r\n",
        "print(\"\")\r\n",
        "print(item.question)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Rank</th>\n",
              "      <th>Name</th>\n",
              "      <th>Nationality</th>\n",
              "      <th>Time (hand)</th>\n",
              "      <th>Notes</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>nan</td>\n",
              "      <td>Tommy Green</td>\n",
              "      <td>Great Britain</td>\n",
              "      <td>4:50:10</td>\n",
              "      <td>OR</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>nan</td>\n",
              "      <td>Janis Dalins</td>\n",
              "      <td>Latvia</td>\n",
              "      <td>4:57:20</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>nan</td>\n",
              "      <td>Ugo Frigerio</td>\n",
              "      <td>Italy</td>\n",
              "      <td>4:59:06</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>4.0</td>\n",
              "      <td>Karl Hahnel</td>\n",
              "      <td>Germany</td>\n",
              "      <td>5:06:06</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>5.0</td>\n",
              "      <td>Ettore Rivolta</td>\n",
              "      <td>Italy</td>\n",
              "      <td>5:07:39</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>6.0</td>\n",
              "      <td>Paul Sievert</td>\n",
              "      <td>Germany</td>\n",
              "      <td>5:16:41</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>7.0</td>\n",
              "      <td>Henri Quintric</td>\n",
              "      <td>France</td>\n",
              "      <td>5:27:25</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>8.0</td>\n",
              "      <td>Ernie Crosbie</td>\n",
              "      <td>United States</td>\n",
              "      <td>5:28:02</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>9.0</td>\n",
              "      <td>Bill Chisholm</td>\n",
              "      <td>United States</td>\n",
              "      <td>5:51:00</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>9</th>\n",
              "      <td>10.0</td>\n",
              "      <td>Alfred Maasik</td>\n",
              "      <td>Estonia</td>\n",
              "      <td>6:19:00</td>\n",
              "      <td>nan</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>10</th>\n",
              "      <td>nan</td>\n",
              "      <td>Henry Cieman</td>\n",
              "      <td>Canada</td>\n",
              "      <td>nan</td>\n",
              "      <td>DNF</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>11</th>\n",
              "      <td>nan</td>\n",
              "      <td>John Moralis</td>\n",
              "      <td>Greece</td>\n",
              "      <td>nan</td>\n",
              "      <td>DNF</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>12</th>\n",
              "      <td>nan</td>\n",
              "      <td>Francesco Pretti</td>\n",
              "      <td>Italy</td>\n",
              "      <td>nan</td>\n",
              "      <td>DNF</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>13</th>\n",
              "      <td>nan</td>\n",
              "      <td>Arthur Tell Schwab</td>\n",
              "      <td>Switzerland</td>\n",
              "      <td>nan</td>\n",
              "      <td>DNF</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>14</th>\n",
              "      <td>nan</td>\n",
              "      <td>Harry Hinkel</td>\n",
              "      <td>United States</td>\n",
              "      <td>nan</td>\n",
              "      <td>DNF</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "    Rank                Name    Nationality Time (hand) Notes\n",
              "0    nan         Tommy Green  Great Britain     4:50:10    OR\n",
              "1    nan        Janis Dalins         Latvia     4:57:20   nan\n",
              "2    nan        Ugo Frigerio          Italy     4:59:06   nan\n",
              "3    4.0         Karl Hahnel        Germany     5:06:06   nan\n",
              "4    5.0      Ettore Rivolta          Italy     5:07:39   nan\n",
              "5    6.0        Paul Sievert        Germany     5:16:41   nan\n",
              "6    7.0      Henri Quintric         France     5:27:25   nan\n",
              "7    8.0       Ernie Crosbie  United States     5:28:02   nan\n",
              "8    9.0       Bill Chisholm  United States     5:51:00   nan\n",
              "9   10.0       Alfred Maasik        Estonia     6:19:00   nan\n",
              "10   nan        Henry Cieman         Canada         nan   DNF\n",
              "11   nan        John Moralis         Greece         nan   DNF\n",
              "12   nan    Francesco Pretti          Italy         nan   DNF\n",
              "13   nan  Arthur Tell Schwab    Switzerland         nan   DNF\n",
              "14   nan        Harry Hinkel  United States         nan   DNF"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "['who are all the athletes?', 'where are they from?', 'along with paul sievert, which athlete is from germany?']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yw8MqIExLnnq"
      },
      "source": [
        "We can see that there are 3 sequential questions asked related to the contents of the table. \r\n",
        "\r\n",
        "We can now use `TapasTokenizer` to batch encode this, as follows:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "t5iU5byAICWb"
      },
      "source": [
        "import torch\r\n",
        "from transformers import TapasTokenizer\r\n",
        "\r\n",
        "# initialize the tokenizer\r\n",
        "tokenizer = TapasTokenizer.from_pretrained(\"google/tapas-base\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "5qOBiUPEGgK8",
        "outputId": "7bc36d39-21be-433e-ecde-3f0d81c340ea"
      },
      "source": [
        "encoding = tokenizer(table=table, queries=item.question, answer_coordinates=item.answer_coordinates, answer_text=item.answer_text,\r\n",
        "                     truncation=True, padding=\"max_length\", return_tensors=\"pt\")\r\n",
        "encoding.keys()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dict_keys(['input_ids', 'labels', 'numeric_values', 'numeric_values_scale', 'token_type_ids', 'attention_mask'])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y2JRiKjPRHAF"
      },
      "source": [
        "TAPAS basically flattens every table-question pair before feeding it into a BERT like model:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 137
        },
        "id": "lhipz2_GRNKQ",
        "outputId": "a3ad3993-5173-45c7-a43b-993ab42f77e3"
      },
      "source": [
        "tokenizer.decode(encoding[\"input_ids\"][0])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'[CLS] who are all the athletes? [SEP] rank name nationality time ( hand ) notes [EMPTY] tommy green great britain 4 : 50 : 10 or [EMPTY] janis dalins latvia 4 : 57 : 20 [EMPTY] [EMPTY] ugo frigerio italy 4 : 59 : 06 [EMPTY] 4. 0 karl hahnel germany 5 : 06 : 06 [EMPTY] 5. 0 ettore rivolta italy 5 : 07 : 39 [EMPTY] 6. 0 paul sievert germany 5 : 16 : 41 [EMPTY] 7. 0 henri quintric france 5 : 27 : 25 [EMPTY] 8. 0 ernie crosbie united states 5 : 28 : 02 [EMPTY] 9. 0 bill chisholm united states 5 : 51 : 00 [EMPTY] 10. 0 alfred maasik estonia 6 : 19 : 00 [EMPTY] [EMPTY] henry cieman canada [EMPTY] dnf [EMPTY] john moralis greece [EMPTY] dnf [EMPTY] francesco pretti italy [EMPTY] dnf [EMPTY] arthur tell schwab switzerland [EMPTY] dnf [EMPTY] harry hinkel united states [EMPTY] dnf [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nVeB5IPaN5oN"
      },
      "source": [
        "The `token_type_ids` created here will be of shape (batch_size, sequence_length, 7), as TAPAS uses 7 different token types to encode tabular structure. Let's verify this:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zM0v-pwbN6gR"
      },
      "source": [
        "assert encoding[\"token_type_ids\"].shape == (3, 512, 7)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TMt7cWJMLvue"
      },
      "source": [
        "\r\n",
        "\r\n",
        "One thing we can verify is whether the `prev_label` token type ids are created correctly. These indicate which tokens were (part of) an answer to the previous table-question pair. \r\n",
        "\r\n",
        "The prev_label token type ids of the first example in a batch must always be zero (since there's no previous table-question pair). Let's verify this:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ytUk-H1yL9cc"
      },
      "source": [
        "assert encoding[\"token_type_ids\"][0][:,3].sum() == 0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rJ_o-82nMfK5"
      },
      "source": [
        "However, the `prev_label` token type ids of the second table-question pair in the batch must be set to 1 for the tokens which were an answer to the previous (i.e. the first) table question pair in the batch. The answers to the first table-question pair are the following:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yxT9h2LIMNt3",
        "outputId": "69b29df5-8103-4b55-e8f4-598bd637a546"
      },
      "source": [
        "print(item.answer_text[0])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "['Tommy Green', 'Janis Dalins', 'Ugo Frigerio', 'Karl Hahnel', 'Ettore Rivolta', 'Paul Sievert', 'Henri Quintric', 'Ernie Crosbie', 'Bill Chisholm', 'Alfred Maasik', 'Henry Cieman', 'John Moralis', 'Francesco Pretti', 'Arthur Tell Schwab', 'Harry Hinkel']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CSUkMGAcMpfE"
      },
      "source": [
        "So let's now verify whether the `prev_label` ids of the second table-question pair are set correctly:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Uv6P7OpJGxuu",
        "outputId": "69b92a6a-408f-48f8-9842-dd3842f7188c"
      },
      "source": [
        "for id, prev_label in zip (encoding[\"input_ids\"][1], encoding[\"token_type_ids\"][1][:,3]):\r\n",
        "  if id != 0: # we skip padding tokens\r\n",
        "    print(tokenizer.decode([id]), prev_label.item())"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[CLS] 0\n",
            "where 0\n",
            "are 0\n",
            "they 0\n",
            "from 0\n",
            "? 0\n",
            "[SEP] 0\n",
            "rank 0\n",
            "name 0\n",
            "nationality 0\n",
            "time 0\n",
            "( 0\n",
            "hand 0\n",
            ") 0\n",
            "notes 0\n",
            "[EMPTY] 0\n",
            "tommy 1\n",
            "green 1\n",
            "great 0\n",
            "britain 0\n",
            "4 0\n",
            ": 0\n",
            "50 0\n",
            ": 0\n",
            "10 0\n",
            "or 0\n",
            "[EMPTY] 0\n",
            "jan 1\n",
            "##is 1\n",
            "dali 1\n",
            "##ns 1\n",
            "latvia 0\n",
            "4 0\n",
            ": 0\n",
            "57 0\n",
            ": 0\n",
            "20 0\n",
            "[EMPTY] 0\n",
            "[EMPTY] 0\n",
            "u 1\n",
            "##go 1\n",
            "fr 1\n",
            "##iger 1\n",
            "##io 1\n",
            "italy 0\n",
            "4 0\n",
            ": 0\n",
            "59 0\n",
            ": 0\n",
            "06 0\n",
            "[EMPTY] 0\n",
            "4 0\n",
            ". 0\n",
            "0 0\n",
            "karl 1\n",
            "hahn 1\n",
            "##el 1\n",
            "germany 0\n",
            "5 0\n",
            ": 0\n",
            "06 0\n",
            ": 0\n",
            "06 0\n",
            "[EMPTY] 0\n",
            "5 0\n",
            ". 0\n",
            "0 0\n",
            "et 1\n",
            "##tore 1\n",
            "ri 1\n",
            "##vo 1\n",
            "##lta 1\n",
            "italy 0\n",
            "5 0\n",
            ": 0\n",
            "07 0\n",
            ": 0\n",
            "39 0\n",
            "[EMPTY] 0\n",
            "6 0\n",
            ". 0\n",
            "0 0\n",
            "paul 1\n",
            "si 1\n",
            "##ever 1\n",
            "##t 1\n",
            "germany 0\n",
            "5 0\n",
            ": 0\n",
            "16 0\n",
            ": 0\n",
            "41 0\n",
            "[EMPTY] 0\n",
            "7 0\n",
            ". 0\n",
            "0 0\n",
            "henri 1\n",
            "qui 1\n",
            "##nt 1\n",
            "##ric 1\n",
            "france 0\n",
            "5 0\n",
            ": 0\n",
            "27 0\n",
            ": 0\n",
            "25 0\n",
            "[EMPTY] 0\n",
            "8 0\n",
            ". 0\n",
            "0 0\n",
            "ernie 1\n",
            "cr 1\n",
            "##os 1\n",
            "##bie 1\n",
            "united 0\n",
            "states 0\n",
            "5 0\n",
            ": 0\n",
            "28 0\n",
            ": 0\n",
            "02 0\n",
            "[EMPTY] 0\n",
            "9 0\n",
            ". 0\n",
            "0 0\n",
            "bill 1\n",
            "chi 1\n",
            "##sho 1\n",
            "##lm 1\n",
            "united 0\n",
            "states 0\n",
            "5 0\n",
            ": 0\n",
            "51 0\n",
            ": 0\n",
            "00 0\n",
            "[EMPTY] 0\n",
            "10 0\n",
            ". 0\n",
            "0 0\n",
            "alfred 1\n",
            "ma 1\n",
            "##asi 1\n",
            "##k 1\n",
            "estonia 0\n",
            "6 0\n",
            ": 0\n",
            "19 0\n",
            ": 0\n",
            "00 0\n",
            "[EMPTY] 0\n",
            "[EMPTY] 0\n",
            "henry 1\n",
            "ci 1\n",
            "##eman 1\n",
            "canada 0\n",
            "[EMPTY] 0\n",
            "d 0\n",
            "##n 0\n",
            "##f 0\n",
            "[EMPTY] 0\n",
            "john 1\n",
            "moral 1\n",
            "##is 1\n",
            "greece 0\n",
            "[EMPTY] 0\n",
            "d 0\n",
            "##n 0\n",
            "##f 0\n",
            "[EMPTY] 0\n",
            "francesco 1\n",
            "pre 1\n",
            "##tti 1\n",
            "italy 0\n",
            "[EMPTY] 0\n",
            "d 0\n",
            "##n 0\n",
            "##f 0\n",
            "[EMPTY] 0\n",
            "arthur 1\n",
            "tell 1\n",
            "sc 1\n",
            "##hwa 1\n",
            "##b 1\n",
            "switzerland 0\n",
            "[EMPTY] 0\n",
            "d 0\n",
            "##n 0\n",
            "##f 0\n",
            "[EMPTY] 0\n",
            "harry 1\n",
            "hi 1\n",
            "##nk 1\n",
            "##el 1\n",
            "united 0\n",
            "states 0\n",
            "[EMPTY] 0\n",
            "d 0\n",
            "##n 0\n",
            "##f 0\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wjVk49fO6u8H"
      },
      "source": [
        "This looks OK! Be sure to check this, because the token type ids are critical for the performance of TAPAS.\r\n",
        "\r\n",
        "Let's create a PyTorch dataset and corresponding dataloader. Note the __getitem__ method here: in order to properly set the prev_labels token types, we must check whether a table-question pair is the first in a sequence or not. In case it is, we can just encode it. In case it isn't, we need to encode it together with the previous table-question pair.\r\n",
        "\r\n",
        "Note that this is not the most efficient approach, because we're effectively tokenizing each table-question pair twice when applied on the entire dataset (feel free to ping me a more efficient solution)."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C-n9vDTD1-k9"
      },
      "source": [
        "class TableDataset(torch.utils.data.Dataset):\r\n",
        "    def __init__(self, df, tokenizer):\r\n",
        "        self.df = df\r\n",
        "        self.tokenizer = tokenizer\r\n",
        "\r\n",
        "    def __getitem__(self, idx):\r\n",
        "        item = self.df.iloc[idx]\r\n",
        "        table = pd.read_csv(table_csv_path + item.table_file[9:]).astype(str) # TapasTokenizer expects the table data to be text only\r\n",
        "        if item.position != 0:\r\n",
        "          # use the previous table-question pair to correctly set the prev_labels token type ids\r\n",
        "          previous_item = self.df.iloc[idx-1]\r\n",
        "          encoding = self.tokenizer(table=table, \r\n",
        "                                    queries=[previous_item.question, item.question], \r\n",
        "                                    answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates], \r\n",
        "                                    answer_text=[previous_item.answer_text, item.answer_text],\r\n",
        "                                    padding=\"max_length\",\r\n",
        "                                    truncation=True,\r\n",
        "                                    return_tensors=\"pt\"\r\n",
        "          )\r\n",
        "          # use encodings of second table-question pair in the batch\r\n",
        "          encoding = {key: val[-1] for key, val in encoding.items()}\r\n",
        "        else:\r\n",
        "          # this means it's the first table-question pair in a sequence\r\n",
        "          encoding = self.tokenizer(table=table, \r\n",
        "                                    queries=item.question, \r\n",
        "                                    answer_coordinates=item.answer_coordinates, \r\n",
        "                                    answer_text=item.answer_text,\r\n",
        "                                    padding=\"max_length\",\r\n",
        "                                    truncation=True,\r\n",
        "                                    return_tensors=\"pt\"\r\n",
        "          )\r\n",
        "          # remove the batch dimension which the tokenizer adds \r\n",
        "          encoding = {key: val.squeeze(0) for key, val in encoding.items()}\r\n",
        "        return encoding\r\n",
        "\r\n",
        "    def __len__(self):\r\n",
        "        return len(self.df)\r\n",
        "\r\n",
        "train_dataset = TableDataset(df=data, tokenizer=tokenizer)\r\n",
        "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X4CHgnTzwfNp",
        "outputId": "a0980f27-a317-4375-9bc4-0085acad0e5f"
      },
      "source": [
        "train_dataset[0][\"token_type_ids\"].shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([512, 7])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 19
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bZN1psdBy5_s",
        "outputId": "e085d737-3a6f-45e5-c200-7c21916b284a"
      },
      "source": [
        "train_dataset[1][\"input_ids\"].shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([512])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 20
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pHAyf85k_xQt"
      },
      "source": [
        "batch = next(iter(train_dataloader))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FoqySHh-_0JV",
        "outputId": "9c0ab5d9-0a06-4331-80e2-ba3b739cfa92"
      },
      "source": [
        "batch[\"input_ids\"].shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([2, 512])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "g5pjJCCT_53N",
        "outputId": "d2ebbf1e-0701-47c1-8533-a087892bd715"
      },
      "source": [
        "batch[\"token_type_ids\"].shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([2, 512, 7])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xVb1-H-jAEoS"
      },
      "source": [
        "Let's decode the first table-question pair:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 137
        },
        "id": "1vfjT1JC_7zI",
        "outputId": "f1a85d76-96ab-4a4d-f8ae-c7ee913c6d7f"
      },
      "source": [
        "tokenizer.decode(batch[\"input_ids\"][0])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'[CLS] where are the players from? [SEP] pick player team position school 1 ben mcdonald baltimore orioles rhp louisiana state university 2 tyler houston atlanta braves c valley hs ( las vegas, nv ) 3 roger salkeld seattle mariners rhp saugus ( ca ) hs 4 jeff jackson philadelphia phillies of simeon hs ( chicago, il ) 5 donald harris texas rangers of texas tech university 6 paul coleman saint louis cardinals of frankston ( tx ) hs 7 frank thomas chicago white sox 1b auburn university 8 earl cunningham chicago cubs of lancaster ( sc ) hs 9 kyle abbott california angels lhp long beach state university 10 charles johnson montreal expos c westwood hs ( fort pierce, fl ) 11 calvin murray cleveland indians 3b w. t. white high school ( dallas, tx ) 12 jeff juden houston astros rhp salem ( ma ) hs 13 brent mayne kansas city royals c cal state fullerton 14 steve hosey san francisco giants of fresno state university 15 kiki jones los angeles dodgers rhp hillsborough hs ( tampa, fl ) 16 greg blosser boston red sox of sarasota ( fl ) hs 17 cal eldred milwaukee brewers rhp university of iowa 18 willie greene pittsburgh pirates ss jones county hs ( gray, ga ) 19 eddie zosky toronto blue jays ss fresno state university 20 scott bryant cincinnati reds of university of texas 21 greg gohr detroit tigers rhp santa clara university 22 tom goodwin los angeles dodgers of fresno state university 23 mo vaughn boston red sox 1b seton hall university 24 alan zinter new york mets c university of arizona 25 chuck knoblauch minnesota twins 2b texas a & m university 26 scott burrell seattle mariners rhp hamden ( ct ) hs [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sujsp8o9DtsY"
      },
      "source": [
        "#first example should not have any prev_labels set\r\n",
        "assert batch[\"token_type_ids\"][0][:,3].sum() == 0"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EIeql5vfFI6s"
      },
      "source": [
        "Let's decode the second table-question pair and verify some more:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 137
        },
        "id": "WrNo_qMqFOzi",
        "outputId": "b2051f0b-72d8-42e2-a6b6-c5a40eda666f"
      },
      "source": [
        "tokenizer.decode(batch[\"input_ids\"][1])"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'[CLS] which player went to louisiana state university? [SEP] pick player team position school 1 ben mcdonald baltimore orioles rhp louisiana state university 2 tyler houston atlanta braves c valley hs ( las vegas, nv ) 3 roger salkeld seattle mariners rhp saugus ( ca ) hs 4 jeff jackson philadelphia phillies of simeon hs ( chicago, il ) 5 donald harris texas rangers of texas tech university 6 paul coleman saint louis cardinals of frankston ( tx ) hs 7 frank thomas chicago white sox 1b auburn university 8 earl cunningham chicago cubs of lancaster ( sc ) hs 9 kyle abbott california angels lhp long beach state university 10 charles johnson montreal expos c westwood hs ( fort pierce, fl ) 11 calvin murray cleveland indians 3b w. t. white high school ( dallas, tx ) 12 jeff juden houston astros rhp salem ( ma ) hs 13 brent mayne kansas city royals c cal state fullerton 14 steve hosey san francisco giants of fresno state university 15 kiki jones los angeles dodgers rhp hillsborough hs ( tampa, fl ) 16 greg blosser boston red sox of sarasota ( fl ) hs 17 cal eldred milwaukee brewers rhp university of iowa 18 willie greene pittsburgh pirates ss jones county hs ( gray, ga ) 19 eddie zosky toronto blue jays ss fresno state university 20 scott bryant cincinnati reds of university of texas 21 greg gohr detroit tigers rhp santa clara university 22 tom goodwin los angeles dodgers of fresno state university 23 mo vaughn boston red sox 1b seton hall university 24 alan zinter new york mets c university of arizona 25 chuck knoblauch minnesota twins 2b texas a & m university 26 scott burrell seattle mariners rhp hamden ( ct ) hs [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 
[PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9a1OToVqxNap",
        "outputId": "2040d63a-024f-4e65-e17c-2a51630a9226"
      },
      "source": [
        "assert batch[\"labels\"][0].sum() == batch[\"token_type_ids\"][1][:,3].sum()\r\n",
        "print(batch[\"token_type_ids\"][1][:,3].sum())"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "tensor(132)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "x4PRdYvBE1k3",
        "outputId": "31bc6092-57e8-4040-f410-83e314ee4a0f"
      },
      "source": [
        "for id, prev_label in zip(batch[\"input_ids\"][1], batch[\"token_type_ids\"][1][:,3]):\r\n",
        "  if id != 0:\r\n",
        "    print(tokenizer.decode([id]), prev_label.item())"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "[CLS] 0\n",
            "which 0\n",
            "player 0\n",
            "went 0\n",
            "to 0\n",
            "louisiana 0\n",
            "state 0\n",
            "university 0\n",
            "? 0\n",
            "[SEP] 0\n",
            "pick 0\n",
            "player 0\n",
            "team 0\n",
            "position 0\n",
            "school 0\n",
            "1 0\n",
            "ben 0\n",
            "mcdonald 0\n",
            "baltimore 0\n",
            "orioles 0\n",
            "r 0\n",
            "##hp 0\n",
            "louisiana 1\n",
            "state 1\n",
            "university 1\n",
            "2 0\n",
            "tyler 0\n",
            "houston 0\n",
            "atlanta 0\n",
            "braves 0\n",
            "c 0\n",
            "valley 1\n",
            "hs 1\n",
            "( 1\n",
            "las 1\n",
            "vegas 1\n",
            ", 1\n",
            "n 1\n",
            "##v 1\n",
            ") 1\n",
            "3 0\n",
            "roger 0\n",
            "sal 0\n",
            "##kel 0\n",
            "##d 0\n",
            "seattle 0\n",
            "mariners 0\n",
            "r 0\n",
            "##hp 0\n",
            "sa 1\n",
            "##ug 1\n",
            "##us 1\n",
            "( 1\n",
            "ca 1\n",
            ") 1\n",
            "hs 1\n",
            "4 0\n",
            "jeff 0\n",
            "jackson 0\n",
            "philadelphia 0\n",
            "phillies 0\n",
            "of 0\n",
            "simeon 1\n",
            "hs 1\n",
            "( 1\n",
            "chicago 1\n",
            ", 1\n",
            "il 1\n",
            ") 1\n",
            "5 0\n",
            "donald 0\n",
            "harris 0\n",
            "texas 0\n",
            "rangers 0\n",
            "of 0\n",
            "texas 1\n",
            "tech 1\n",
            "university 1\n",
            "6 0\n",
            "paul 0\n",
            "coleman 0\n",
            "saint 0\n",
            "louis 0\n",
            "cardinals 0\n",
            "of 0\n",
            "franks 1\n",
            "##ton 1\n",
            "( 1\n",
            "tx 1\n",
            ") 1\n",
            "hs 1\n",
            "7 0\n",
            "frank 0\n",
            "thomas 0\n",
            "chicago 0\n",
            "white 0\n",
            "sox 0\n",
            "1b 0\n",
            "auburn 1\n",
            "university 1\n",
            "8 0\n",
            "earl 0\n",
            "cunningham 0\n",
            "chicago 0\n",
            "cubs 0\n",
            "of 0\n",
            "lancaster 1\n",
            "( 1\n",
            "sc 1\n",
            ") 1\n",
            "hs 1\n",
            "9 0\n",
            "kyle 0\n",
            "abbott 0\n",
            "california 0\n",
            "angels 0\n",
            "l 0\n",
            "##hp 0\n",
            "long 1\n",
            "beach 1\n",
            "state 1\n",
            "university 1\n",
            "10 0\n",
            "charles 0\n",
            "johnson 0\n",
            "montreal 0\n",
            "expo 0\n",
            "##s 0\n",
            "c 0\n",
            "westwood 1\n",
            "hs 1\n",
            "( 1\n",
            "fort 1\n",
            "pierce 1\n",
            ", 1\n",
            "fl 1\n",
            ") 1\n",
            "11 0\n",
            "calvin 0\n",
            "murray 0\n",
            "cleveland 0\n",
            "indians 0\n",
            "3 0\n",
            "##b 0\n",
            "w 1\n",
            ". 1\n",
            "t 1\n",
            ". 1\n",
            "white 1\n",
            "high 1\n",
            "school 1\n",
            "( 1\n",
            "dallas 1\n",
            ", 1\n",
            "tx 1\n",
            ") 1\n",
            "12 0\n",
            "jeff 0\n",
            "jude 0\n",
            "##n 0\n",
            "houston 0\n",
            "astros 0\n",
            "r 0\n",
            "##hp 0\n",
            "salem 1\n",
            "( 1\n",
            "ma 1\n",
            ") 1\n",
            "hs 1\n",
            "13 0\n",
            "brent 0\n",
            "may 0\n",
            "##ne 0\n",
            "kansas 0\n",
            "city 0\n",
            "royals 0\n",
            "c 0\n",
            "cal 1\n",
            "state 1\n",
            "fuller 1\n",
            "##ton 1\n",
            "14 0\n",
            "steve 0\n",
            "hose 0\n",
            "##y 0\n",
            "san 0\n",
            "francisco 0\n",
            "giants 0\n",
            "of 0\n",
            "fresno 1\n",
            "state 1\n",
            "university 1\n",
            "15 0\n",
            "ki 0\n",
            "##ki 0\n",
            "jones 0\n",
            "los 0\n",
            "angeles 0\n",
            "dodgers 0\n",
            "r 0\n",
            "##hp 0\n",
            "hillsborough 1\n",
            "hs 1\n",
            "( 1\n",
            "tampa 1\n",
            ", 1\n",
            "fl 1\n",
            ") 1\n",
            "16 0\n",
            "greg 0\n",
            "b 0\n",
            "##los 0\n",
            "##ser 0\n",
            "boston 0\n",
            "red 0\n",
            "sox 0\n",
            "of 0\n",
            "sara 1\n",
            "##so 1\n",
            "##ta 1\n",
            "( 1\n",
            "fl 1\n",
            ") 1\n",
            "hs 1\n",
            "17 0\n",
            "cal 0\n",
            "el 0\n",
            "##dre 0\n",
            "##d 0\n",
            "milwaukee 0\n",
            "brewers 0\n",
            "r 0\n",
            "##hp 0\n",
            "university 1\n",
            "of 1\n",
            "iowa 1\n",
            "18 0\n",
            "willie 0\n",
            "greene 0\n",
            "pittsburgh 0\n",
            "pirates 0\n",
            "ss 0\n",
            "jones 1\n",
            "county 1\n",
            "hs 1\n",
            "( 1\n",
            "gray 1\n",
            ", 1\n",
            "ga 1\n",
            ") 1\n",
            "19 0\n",
            "eddie 0\n",
            "z 0\n",
            "##os 0\n",
            "##ky 0\n",
            "toronto 0\n",
            "blue 0\n",
            "jays 0\n",
            "ss 0\n",
            "fresno 1\n",
            "state 1\n",
            "university 1\n",
            "20 0\n",
            "scott 0\n",
            "bryant 0\n",
            "cincinnati 0\n",
            "reds 0\n",
            "of 0\n",
            "university 1\n",
            "of 1\n",
            "texas 1\n",
            "21 0\n",
            "greg 0\n",
            "go 0\n",
            "##hr 0\n",
            "detroit 0\n",
            "tigers 0\n",
            "r 0\n",
            "##hp 0\n",
            "santa 1\n",
            "clara 1\n",
            "university 1\n",
            "22 0\n",
            "tom 0\n",
            "goodwin 0\n",
            "los 0\n",
            "angeles 0\n",
            "dodgers 0\n",
            "of 0\n",
            "fresno 1\n",
            "state 1\n",
            "university 1\n",
            "23 0\n",
            "mo 0\n",
            "vaughn 0\n",
            "boston 0\n",
            "red 0\n",
            "sox 0\n",
            "1b 0\n",
            "seton 1\n",
            "hall 1\n",
            "university 1\n",
            "24 0\n",
            "alan 0\n",
            "z 0\n",
            "##int 0\n",
            "##er 0\n",
            "new 0\n",
            "york 0\n",
            "mets 0\n",
            "c 0\n",
            "university 1\n",
            "of 1\n",
            "arizona 1\n",
            "25 0\n",
            "chuck 0\n",
            "knob 0\n",
            "##lau 0\n",
            "##ch 0\n",
            "minnesota 0\n",
            "twins 0\n",
            "2 0\n",
            "##b 0\n",
            "texas 1\n",
            "a 1\n",
            "& 1\n",
            "m 1\n",
            "university 1\n",
            "26 0\n",
            "scott 0\n",
            "burr 0\n",
            "##ell 0\n",
            "seattle 0\n",
            "mariners 0\n",
            "r 0\n",
            "##hp 0\n",
            "ham 1\n",
            "##den 1\n",
            "( 1\n",
            "ct 1\n",
            ") 1\n",
            "hs 1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAem9QnIxoKb"
      },
      "source": [
        "## Define the model\r\n",
        "\r\n",
        "Here we initialize the model with a pre-trained base and randomly initialized cell selection head, and move it to the GPU (if available).\r\n",
        "\r\n",
        "Note that the `google/tapas-base` checkpoint has (by default) an SQA configuration, so we don't need to specify any additional hyperparameters."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000,
          "referenced_widgets": [
            "768723f1af4c4bf497064a9796567382",
            "82bfe5d88c5546358de33952e42221a3",
            "c6aadbda246d47fcad9660c2f57926a0",
            "cd910003940b498e94db9746091e4c3c",
            "a11bc2e0d34543d1bb0ab708e0cc1a67",
            "d59737fa0a424a90ba02110e4cf1b639",
            "d534c2ebdbb144efa59500fcadf6e118",
            "f6fef10b29f74c458d05c2cbda46bf4a",
            "695d0c44ebbe4a55a17c735645c26c82",
            "7e3cd6c49143436bad3d84be7ca2cf79",
            "be6159a5be0945629a517297df1170b8",
            "b889ffa28a1949e9ac46e205ee582688",
            "3c95dd6249784fe6a2a466737e5f5866",
            "cf9382ee988f4787a88340d3b2af95f3",
            "45fdca8a04f94b43b021c590181e8a48",
            "3bad5ab136084c2e92b55153828a403f"
          ]
        },
        "id": "_OsPodbiDliR",
        "outputId": "e2094861-fc6c-42b9-b12b-0f6824ad3048"
      },
      "source": [
        "from transformers import TapasForQuestionAnswering\r\n",
        "\r\n",
        "model = TapasForQuestionAnswering.from_pretrained(\"google/tapas-base\")\r\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
        "\r\n",
        "model.to(device)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "768723f1af4c4bf497064a9796567382",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1432.0, style=ProgressStyle(description…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "695d0c44ebbe4a55a17c735645c26c82",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=442768791.0, style=ProgressStyle(descri…"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "Some weights of TapasForQuestionAnswering were not initialized from the model checkpoint at google/tapas-base and are newly initialized: ['column_output_bias', 'output_bias', 'column_output_weights', 'output_weights']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "TapasForQuestionAnswering(\n",
              "  (tapas): TapasModel(\n",
              "    (embeddings): TapasEmbeddings(\n",
              "      (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
              "      (position_embeddings): Embedding(1024, 768)\n",
              "      (token_type_embeddings_0): Embedding(3, 768)\n",
              "      (token_type_embeddings_1): Embedding(256, 768)\n",
              "      (token_type_embeddings_2): Embedding(256, 768)\n",
              "      (token_type_embeddings_3): Embedding(2, 768)\n",
              "      (token_type_embeddings_4): Embedding(256, 768)\n",
              "      (token_type_embeddings_5): Embedding(256, 768)\n",
              "      (token_type_embeddings_6): Embedding(10, 768)\n",
              "      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "      (dropout): Dropout(p=0.07, inplace=False)\n",
              "    )\n",
              "    (encoder): TapasEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (1): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (2): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (3): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (4): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (5): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (6): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (7): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (8): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (9): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (10): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "        (11): TapasLayer(\n",
              "          (attention): TapasAttention(\n",
              "            (self): TapasSelfAttention(\n",
              "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): TapasSelfOutput(\n",
              "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "              (dropout): Dropout(p=0.07, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): TapasIntermediate(\n",
              "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
              "          )\n",
              "          (output): TapasOutput(\n",
              "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
              "            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
              "            (dropout): Dropout(p=0.07, inplace=False)\n",
              "          )\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "    (pooler): TapasPooler(\n",
              "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
              "      (activation): Tanh()\n",
              "    )\n",
              "  )\n",
              "  (dropout): Dropout(p=0.07, inplace=False)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dtvkIFkCzdsg"
      },
      "source": [
        "## Training the model\r\n",
        "\r\n",
        "Let's fine-tune the model in well-known PyTorch fashion:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "HyEZVmbdzWV9",
        "outputId": "2a6c73a9-9572-480e-c24f-209e6c5804fb"
      },
      "source": [
        "import torch\r\n",
        "\r\n",
        "# transformers.AdamW is deprecated in favour of the PyTorch implementation;\r\n",
        "# eps and weight_decay are set to the defaults that transformers.AdamW used\r\n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)\r\n",
        "\r\n",
        "# models loaded with from_pretrained() are in evaluation mode by default,\r\n",
        "# so explicitly switch to training mode (this enables dropout) before fine-tuning\r\n",
        "model.train()\r\n",
        "\r\n",
        "for epoch in range(10):  # loop over the dataset multiple times\r\n",
        "    print(\"Epoch:\", epoch)\r\n",
        "    for batch in train_dataloader:\r\n",
        "        # move the inputs of this batch to the device the model lives on\r\n",
        "        input_ids = batch[\"input_ids\"].to(device)\r\n",
        "        attention_mask = batch[\"attention_mask\"].to(device)\r\n",
        "        token_type_ids = batch[\"token_type_ids\"].to(device)\r\n",
        "        labels = batch[\"labels\"].to(device)\r\n",
        "\r\n",
        "        # zero the parameter gradients\r\n",
        "        optimizer.zero_grad()\r\n",
        "        # forward + backward + optimize\r\n",
        "        outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,\r\n",
        "                        labels=labels)\r\n",
        "        loss = outputs.loss\r\n",
        "        print(\"Loss:\", loss.item())\r\n",
        "        loss.backward()\r\n",
        "        optimizer.step()\r\n",
        "\r\n",
        "# training is done: go back to evaluation mode (disables dropout) for inference\r\n",
        "model.eval()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 0\n",
            "Loss: 2.283051013946533\n",
            "Loss: 2.321470022201538\n",
            "Loss: 1.304502010345459\n",
            "Loss: 1.8132498264312744\n",
            "Loss: 1.5585863590240479\n",
            "Loss: 2.8958029747009277\n",
            "Loss: 2.3229925632476807\n",
            "Loss: 2.9768738746643066\n",
            "Loss: 2.43325138092041\n",
            "Loss: 2.588594913482666\n",
            "Loss: 2.432821273803711\n",
            "Loss: 2.077129602432251\n",
            "Loss: 2.5189807415008545\n",
            "Loss: 1.117794394493103\n",
            "Epoch: 1\n",
            "Loss: 2.6343977451324463\n",
            "Loss: 1.1567877531051636\n",
            "Loss: 0.8736000061035156\n",
            "Loss: 1.1256351470947266\n",
            "Loss: 1.193580985069275\n",
            "Loss: 2.019536018371582\n",
            "Loss: 1.9377449750900269\n",
            "Loss: 2.7245376110076904\n",
            "Loss: 2.4017510414123535\n",
            "Loss: 1.299104928970337\n",
            "Loss: 1.1747502088546753\n",
            "Loss: 1.3556476831436157\n",
            "Loss: 1.2996423244476318\n",
            "Loss: 0.8333960771560669\n",
            "Epoch: 2\n",
            "Loss: 1.1792304515838623\n",
            "Loss: 0.7969248294830322\n",
            "Loss: 0.5043489336967468\n",
            "Loss: 0.8913598656654358\n",
            "Loss: 1.618265986442566\n",
            "Loss: 1.099987268447876\n",
            "Loss: 0.8575657606124878\n",
            "Loss: 1.231074333190918\n",
            "Loss: 1.5082621574401855\n",
            "Loss: 0.736372709274292\n",
            "Loss: 0.9882394671440125\n",
            "Loss: 0.8554693460464478\n",
            "Loss: 0.646073579788208\n",
            "Loss: 0.5032682418823242\n",
            "Epoch: 3\n",
            "Loss: 0.6410143375396729\n",
            "Loss: 0.6181665658950806\n",
            "Loss: 0.343902051448822\n",
            "Loss: 0.24581563472747803\n",
            "Loss: 0.6221672296524048\n",
            "Loss: 0.2922801077365875\n",
            "Loss: 0.11909369379281998\n",
            "Loss: 0.6472622156143188\n",
            "Loss: 2.407520055770874\n",
            "Loss: 1.0915907621383667\n",
            "Loss: 0.573483407497406\n",
            "Loss: 0.314322829246521\n",
            "Loss: 0.5054236650466919\n",
            "Loss: 0.3973156809806824\n",
            "Epoch: 4\n",
            "Loss: 0.1462380290031433\n",
            "Loss: 0.4669300615787506\n",
            "Loss: 0.12367662787437439\n",
            "Loss: 0.20674511790275574\n",
            "Loss: 1.7252687215805054\n",
            "Loss: 1.0967378616333008\n",
            "Loss: 0.6795816421508789\n",
            "Loss: 0.29488128423690796\n",
            "Loss: 0.6297208070755005\n",
            "Loss: 0.7574292421340942\n",
            "Loss: 2.1319501399993896\n",
            "Loss: 0.3807623088359833\n",
            "Loss: 1.796015977859497\n",
            "Loss: 0.3021317720413208\n",
            "Epoch: 5\n",
            "Loss: 2.1536195278167725\n",
            "Loss: 0.1456122249364853\n",
            "Loss: 0.2584973871707916\n",
            "Loss: 0.23840856552124023\n",
            "Loss: 0.4197732210159302\n",
            "Loss: 0.8440569639205933\n",
            "Loss: 0.587626039981842\n",
            "Loss: 0.7899488210678101\n",
            "Loss: 1.723198413848877\n",
            "Loss: 0.49360978603363037\n",
            "Loss: 0.5526523590087891\n",
            "Loss: 0.2673705220222473\n",
            "Loss: 0.4922328293323517\n",
            "Loss: 0.2313116192817688\n",
            "Epoch: 6\n",
            "Loss: 0.35565513372421265\n",
            "Loss: 0.2256387621164322\n",
            "Loss: 0.07711786031723022\n",
            "Loss: 0.062317755073308945\n",
            "Loss: 0.6037214994430542\n",
            "Loss: 0.1518193930387497\n",
            "Loss: 0.04495159536600113\n",
            "Loss: 0.47260624170303345\n",
            "Loss: 0.3101451098918915\n",
            "Loss: 0.35189706087112427\n",
            "Loss: 0.19305744767189026\n",
            "Loss: 0.1478348821401596\n",
            "Loss: 0.23966407775878906\n",
            "Loss: 0.10663879662752151\n",
            "Epoch: 7\n",
            "Loss: 0.1251879334449768\n",
            "Loss: 0.03940412029623985\n",
            "Loss: 0.037242159247398376\n",
            "Loss: 0.23117898404598236\n",
            "Loss: 0.21706533432006836\n",
            "Loss: 0.12712527811527252\n",
            "Loss: 0.02821732871234417\n",
            "Loss: 0.33739519119262695\n",
            "Loss: 0.3730405569076538\n",
            "Loss: 0.24556192755699158\n",
            "Loss: 0.10160182416439056\n",
            "Loss: 0.07098039239645004\n",
            "Loss: 0.054128892719745636\n",
            "Loss: 0.042726725339889526\n",
            "Epoch: 8\n",
            "Loss: 0.16904376447200775\n",
            "Loss: 0.01094090472906828\n",
            "Loss: 0.016752174124121666\n",
            "Loss: 0.0321066752076149\n",
            "Loss: 0.12141289561986923\n",
            "Loss: 0.21206972002983093\n",
            "Loss: 0.03060653805732727\n",
            "Loss: 0.15915940701961517\n",
            "Loss: 0.14235229790210724\n",
            "Loss: 0.21385076642036438\n",
            "Loss: 0.07888767868280411\n",
            "Loss: 0.061028994619846344\n",
            "Loss: 0.3621511459350586\n",
            "Loss: 0.021799277514219284\n",
            "Epoch: 9\n",
            "Loss: 0.027316564694046974\n",
            "Loss: 1.1551400423049927\n",
            "Loss: 0.27995944023132324\n",
            "Loss: 0.013127731159329414\n",
            "Loss: 0.1084042340517044\n",
            "Loss: 0.057816676795482635\n",
            "Loss: 0.007765657734125853\n",
            "Loss: 0.10145987570285797\n",
            "Loss: 0.18493309617042542\n",
            "Loss: 0.33809077739715576\n",
            "Loss: 0.030108436942100525\n",
            "Loss: 0.08543295413255692\n",
            "Loss: 0.025884397327899933\n",
            "Loss: 0.0188467875123024\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kmugNHCN20QF"
      },
      "source": [
        "## Inference\r\n",
        "\r\n",
        "As SQA is a bit different due to its conversational nature, we need to run every example of a batch through the model one by one (sequentially), overwriting the `prev_labels` token types (which were created by the tokenizer) with the answer the model predicted for the previous example. It is based on the [following code](https://github.com/google-research/tapas/blob/f458b6624b8aa75961a0ab78e9847355022940d3/tapas/experiments/prediction_utils.py#L92) from the official implementation:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MhJwoaSy26PD"
      },
      "source": [
        "import collections\r\n",
        "import numpy as np\r\n",
        "\r\n",
        "def compute_prediction_sequence(model, data, device):\r\n",
        "  \"\"\"Computes predictions using model's answers to the previous questions.\r\n",
        "\r\n",
        "  The examples in `data` are run through the model one at a time. Before every\r\n",
        "  forward pass except the first, column 3 of the example's token type ids (the\r\n",
        "  prev labels) is overwritten with the cells the model selected for the\r\n",
        "  previous example.\r\n",
        "\r\n",
        "  Args:\r\n",
        "    model: a torch.nn.Module that accepts `input_ids`, `attention_mask` and\r\n",
        "      `token_type_ids` keyword arguments and returns an object with a `logits`\r\n",
        "      attribute of shape (1, seq_len). It is run in evaluation mode and put\r\n",
        "      back in its original mode afterwards.\r\n",
        "    data: mapping with \"input_ids\" and \"attention_mask\" tensors of shape\r\n",
        "      (num_examples, seq_len) and a \"token_type_ids\" tensor of shape\r\n",
        "      (num_examples, seq_len, 7). It is not modified.\r\n",
        "    device: device on which the forward passes are run.\r\n",
        "\r\n",
        "  Returns:\r\n",
        "    Tensor of shape (num_examples, seq_len) holding the logits of all examples.\r\n",
        "  \"\"\"\r\n",
        "\r\n",
        "  # prepare data\r\n",
        "  input_ids = data[\"input_ids\"].to(device)\r\n",
        "  attention_mask = data[\"attention_mask\"].to(device)\r\n",
        "  # .to(device) hands back the very same tensor when it already lives on `device`;\r\n",
        "  # clone so that the in-place update of the prev labels below never leaks into `data`\r\n",
        "  token_type_ids = data[\"token_type_ids\"].to(device).clone()\r\n",
        "\r\n",
        "  all_logits = []\r\n",
        "  coords_to_answer = None  # cells selected for the previous example, if there is one\r\n",
        "\r\n",
        "  # inference only: no dropout and no gradient tracking\r\n",
        "  was_training = model.training\r\n",
        "  model.eval()\r\n",
        "  try:\r\n",
        "    with torch.no_grad():\r\n",
        "      for idx in range(input_ids.shape[0]):\r\n",
        "        # get the example\r\n",
        "        input_ids_example = input_ids[idx] # shape (seq_len,)\r\n",
        "        attention_mask_example = attention_mask[idx] # shape (seq_len,)\r\n",
        "        token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)\r\n",
        "\r\n",
        "        # (col, row) of the cell every token belongs to, or None for tokens outside the table;\r\n",
        "        # the id columns are converted to lists once instead of once per token\r\n",
        "        token_coords = [\r\n",
        "            (col_id - 1, row_id - 1) if segment_id == 1 and col_id >= 1 and row_id >= 1 else None\r\n",
        "            for segment_id, col_id, row_id in zip(token_type_ids_example[:, 0].tolist(),\r\n",
        "                                                  token_type_ids_example[:, 1].tolist(),\r\n",
        "                                                  token_type_ids_example[:, 2].tolist())\r\n",
        "        ]\r\n",
        "\r\n",
        "        if coords_to_answer is not None:\r\n",
        "          # set the prev label ids to the model's answer to the previous question; a cell that\r\n",
        "          # was not part of the previous example cannot have been selected, hence the False default\r\n",
        "          model_label_ids = [0 if coords is None else int(coords_to_answer.get(coords, False))\r\n",
        "                             for coords in token_coords]\r\n",
        "          token_type_ids_example[:, 3] = torch.tensor(model_label_ids, dtype=torch.long, device=device)\r\n",
        "\r\n",
        "        # forward pass to obtain the logits\r\n",
        "        outputs = model(input_ids=input_ids_example.unsqueeze(0),\r\n",
        "                        attention_mask=attention_mask_example.unsqueeze(0),\r\n",
        "                        token_type_ids=token_type_ids_example.unsqueeze(0))\r\n",
        "        logits = outputs.logits\r\n",
        "        all_logits.append(logits)\r\n",
        "\r\n",
        "        # convert logits to probabilities (shape (1, seq_len)); masked-out tokens get probability 0\r\n",
        "        probabilities = torch.sigmoid(logits) * attention_mask_example.type(torch.float32).to(logits.device)\r\n",
        "\r\n",
        "        # collect the probabilities of all tokens of every cell\r\n",
        "        coords_to_probs = collections.defaultdict(list)\r\n",
        "        for coords, p in zip(token_coords, probabilities.squeeze(0).tolist()):\r\n",
        "          if coords is not None:\r\n",
        "            coords_to_probs[coords].append(p)\r\n",
        "\r\n",
        "        # map cell coordinates to True/False (whether the mean prob of all cell tokens is > 0.5)\r\n",
        "        coords_to_answer = {coords: np.array(probs).mean() > 0.5\r\n",
        "                            for coords, probs in coords_to_probs.items()}\r\n",
        "  finally:\r\n",
        "    # leave the model in the mode it was handed over in\r\n",
        "    model.train(was_training)\r\n",
        "\r\n",
        "  logits_batch = torch.cat(tuple(all_logits), 0)\r\n",
        "\r\n",
        "  return logits_batch"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jflxDE_BfVg9"
      },
      "source": [
        "# small demo table (every cell value is a string) plus a conversational list of queries\n",
        "data = {\n",
        "    'Actors': [\"Brad Pitt\", \"Leonardo Di Caprio\", \"George Clooney\"],\n",
        "    'Age': [\"56\", \"45\", \"59\"],\n",
        "    'Number of movies': [\"87\", \"53\", \"69\"],\n",
        "    'Date of birth': [\"7 february 1967\", \"10 june 1996\", \"28 november 1967\"],\n",
        "}\n",
        "queries = [\n",
        "    \"How many movies has George Clooney played in?\",\n",
        "    \"How old is he?\",\n",
        "    \"What's his date of birth?\",\n",
        "]\n",
        "\n",
        "table = pd.DataFrame(data)\n",
        "\n",
        "# encode the table together with all queries, then predict the queries one after another\n",
        "inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors=\"pt\")\n",
        "logits = compute_prediction_sequence(model, inputs, device)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k_a_Y-rDq__o"
      },
      "source": [
        "Finally, we can use the handy `convert_logits_to_predictions` function of `TapasTokenizer` to convert the logits into predicted coordinates, and print out the result:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5fAcNOVsqoVD"
      },
      "source": [
        "# the call returns a sequence with a single element here: unpack it\n",
        "(predicted_answer_coordinates,) = tokenizer.convert_logits_to_predictions(inputs, logits.detach().cpu())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QP4AHMxFujhV",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 254
        },
        "outputId": "aed2fc99-957b-4b9f-e804-b426a80de8df"
      },
      "source": [
        "# look up the predicted cell coordinates in the Pandas dataframe:\n",
        "# a single cell is taken as is, multiple cells are joined with \", \"\n",
        "answers = [\n",
        "    table.iat[coordinates[0]] if len(coordinates) == 1\n",
        "    else \", \".join([table.iat[coordinate] for coordinate in coordinates])\n",
        "    for coordinates in predicted_answer_coordinates\n",
        "]\n",
        "\n",
        "display(table)\n",
        "print(\"\")\n",
        "for query, answer in zip(queries, answers):\n",
        "  print(query)\n",
        "  print(\"Predicted answer: \" + answer)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Actors</th>\n",
              "      <th>Age</th>\n",
              "      <th>Number of movies</th>\n",
              "      <th>Date of birth</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>Brad Pitt</td>\n",
              "      <td>56</td>\n",
              "      <td>87</td>\n",
              "      <td>7 february 1967</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>Leonardo Di Caprio</td>\n",
              "      <td>45</td>\n",
              "      <td>53</td>\n",
              "      <td>10 june 1996</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>George Clooney</td>\n",
              "      <td>59</td>\n",
              "      <td>69</td>\n",
              "      <td>28 november 1967</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "               Actors Age Number of movies     Date of birth\n",
              "0           Brad Pitt  56               87   7 february 1967\n",
              "1  Leonardo Di Caprio  45               53      10 june 1996\n",
              "2      George Clooney  59               69  28 november 1967"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "\n",
            "How many movies has George Clooney played in?\n",
            "Predicted answer: Brad Pitt\n",
            "How old is he?\n",
            "Predicted answer: Brad Pitt, Leonardo Di Caprio, George Clooney\n",
            "What's his date of birth?\n",
            "Predicted answer: Brad Pitt, Leonardo Di Caprio, George Clooney\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6L0KBaPjG7uj"
      },
      "source": [
        "Note that the results here are not correct. That is to be expected, since we only trained on 28 examples and tested on an entirely different example. In reality, you should train on the entire dataset. Doing so results in the `google/tapas-base-finetuned-sqa` checkpoint."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y4S-TIGSvqhZ"
      },
      "source": [
        "## Legacy\r\n",
        "\r\n",
        "The code below was considered during the creation of this tutorial, but eventually not used."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ox1ZECiJ5vSD"
      },
      "source": [
        "# grouped = data.groupby(data.position)\r\n",
        "# test = grouped.get_group(0)\r\n",
        "# test.index"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L0IuO6vivrw_"
      },
      "source": [
        "def custom_collate_fn(data):\r\n",
        "  \"\"\"\r\n",
        "  Collate a list of encodings into one dictionary.\r\n",
        "\r\n",
        "  For every key of the first element, the tensors stored under that key in all\r\n",
        "  elements are concatenated along dimension 0.\r\n",
        "\r\n",
        "  Args:\r\n",
        "    data:\r\n",
        "      a list of dictionaries (each dictionary is what the __getitem__ method of TableDataset returns)\r\n",
        "  \"\"\"\r\n",
        "  return {key: torch.cat([example[key] for example in data], dim=0) for key in data[0].keys()}\r\n",
        "\r\n",
        "class TableDataset(torch.utils.data.Dataset):\r\n",
        "    \"\"\"Dataset that tokenizes one table-question pair of `df` per item.\"\"\"\r\n",
        "\r\n",
        "    def __init__(self, df, tokenizer):\r\n",
        "        self.df = df\r\n",
        "        self.tokenizer = tokenizer\r\n",
        "\r\n",
        "    def __len__(self):\r\n",
        "        return len(self.df)\r\n",
        "\r\n",
        "    def __getitem__(self, idx):\r\n",
        "        row = self.df.iloc[idx]\r\n",
        "        # TapasTokenizer expects the table data to be text only\r\n",
        "        table = pd.read_csv(table_csv_path + row.table_file[9:]).astype(str)\r\n",
        "\r\n",
        "        if row.position == 0:\r\n",
        "            # this means it's the first table-question pair in a sequence\r\n",
        "            return self.tokenizer(table=table,\r\n",
        "                                  queries=row.question,\r\n",
        "                                  answer_coordinates=row.answer_coordinates,\r\n",
        "                                  answer_text=row.answer_text,\r\n",
        "                                  padding=\"max_length\",\r\n",
        "                                  truncation=True,\r\n",
        "                                  return_tensors=\"pt\")\r\n",
        "\r\n",
        "        # otherwise, encode it together with the previous table-question pair\r\n",
        "        previous_row = self.df.iloc[idx - 1]\r\n",
        "        pair_encoding = self.tokenizer(table=table,\r\n",
        "                                       queries=[previous_row.question, row.question],\r\n",
        "                                       answer_coordinates=[previous_row.answer_coordinates, row.answer_coordinates],\r\n",
        "                                       answer_text=[previous_row.answer_text, row.answer_text],\r\n",
        "                                       padding=\"max_length\",\r\n",
        "                                       truncation=True,\r\n",
        "                                       return_tensors=\"pt\")\r\n",
        "        # keep only the last entry along the first dimension (the current question),\r\n",
        "        # which also removes the batch dimension the tokenizer adds\r\n",
        "        return {key: val[-1] for key, val in pair_encoding.items()}\r\n",
        "\r\n",
        "train_dataset = TableDataset(df=grouped, tokenizer=tokenizer)\r\n",
        "train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, collate_fn=custom_collate_fn)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}